import tensorflow as tf
import pickle
import multiprocessing

from examples.SUBTRACTION.architectures import *
from examples.SUBTRACTION.data import create_joint_dataloader, create_dataloaders
from examples.SUBTRACTION.evaluate import *


def baseline_train(seed, epochs):
    """Train a ``StructuredODBaseline`` from pretrained weights and save its log.

    The function builds a regressor and a classifier on top of one shared
    ``OperatorConvolutionsSmaller`` instance, restores their weights from the
    seed-specific checkpoints under ``examples/SUBTRACTION/networks``, trains
    the baseline with Adam, evaluates it on a test dataloader and pickles the
    baseline's logger under ``examples/SUBTRACTION/results``.

    Args:
        seed: Value substituted into the checkpoint and result file names.
            NOTE(review): it is only used for file names here; it is not
            passed to any random number generator in this function.
        epochs: Number of epochs forwarded to ``baseline.train``; also part
            of the result file name.
    """
    # Both heads are constructed around the same convolutional module.
    conv = OperatorConvolutionsSmaller(size=5)
    rpn = ObjectDetectorRegressor(conv, N=4, simple=True)
    classifier = ObjectDetectorClassifier(conv, simple=True)

    # Small curriculum/regression loaders: D_mem is handed to the baseline
    # constructor, D_mem_val to the evaluation callbacks during training.
    D_mem = create_joint_dataloader("train", batch_size=2, data_size=100, curriculum=True, samedistr=False, regression=True)
    D_mem_val = create_joint_dataloader("val", batch_size=50, curriculum=True, samedistr=False, regression=True)
    # Larger loaders built with constraints=True, used for the training loop.
    D = create_joint_dataloader("train", batch_size=10, data_size=15000, samedistr=False, constraints=True)
    D_val = create_joint_dataloader("val", batch_size=50, samedistr=False, constraints=True)

    # Start from the checkpoints that belong to this seed.
    rpn.load_weights("examples/SUBTRACTION/networks/rpnOD_100s_50e_2b_seed{}".format(seed))
    classifier.load_weights("examples/SUBTRACTION/networks/classifierOD_100s_50e_2b_seed{}".format(seed))

    baseline = StructuredODBaseline(rpn, classifier, D_mem)
    optimiser = tf.keras.optimizers.Adam(learning_rate=1e-3)
    baseline.optimiser = optimiser
    # eval_fns and fn_args are parallel lists: the i-th argument list is
    # presumably passed to the i-th evaluation function -- confirm in train().
    baseline.train(D, epochs, val_data=D_val, eval_fns=[simple_evaluate_regression_IoU,
                                                        simple_evaluate_ODbaseline],
                   fn_args=[[rpn, D_mem_val, True], [baseline, D_mem_val]])

    D_test = create_joint_dataloader("test", batch_size=50, data_size=5000, curriculum=True, samedistr=False, regression=True)

    ious = simple_evaluate_regression_IoU(rpn, D_test, False)
    accs = simple_evaluate_ODbaseline(baseline, D_test)
    # NOTE(review): -1 looks like a sentinel step index for test-time
    # metrics -- confirm against the logger implementation.
    baseline.logger.log('test_iou', -1, ious)
    baseline.logger.log('test_acc', -1, accs)

    # Use a context manager so the results file is flushed and closed
    # deterministically instead of relying on garbage collection.
    results_path = "examples/SUBTRACTION/results/ADDEDbaseline_diffdistr_15000s_{}e_10b_seed{}".format(epochs, seed)
    with open(results_path, "wb") as results_file:
        pickle.dump(baseline.logger, results_file)


# Entry-point guard: with the "spawn" start method each child process
# re-imports this module, so unguarded process creation at import time would
# be re-executed inside every child.
if __name__ == "__main__":
    # One run per seed 0..9, 30 epochs each. Every run is joined before the
    # next one starts, so the runs execute one at a time in separate
    # processes.
    for i in range(10):
        p = multiprocessing.Process(target=baseline_train, args=(i, 30))
        p.start()
        p.join()
